{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune SAM on custom datasets (3D example)\n",
    "1. Prepare the original 3D images in `data/FLARE22Train/` (download link: https://zenodo.org/record/7860267)\n",
    "2. Run `pre_CT.py` for pre-processing. Expected output: `./data/Npz_files/CT_Abd-Gallbladder/` with `train/` and `test/` subfolders of `.npz` files\n",
    "3. Download the SAM ViT-B checkpoint to `work_dir/SAM/sam_vit_b_01ec64.pth`\n",
    "4. Run this fine-tuning tutorial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% set up environment\n",
    "# standard library\n",
    "import os\n",
    "\n",
    "# third-party\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "import monai\n",
    "\n",
    "# Segment Anything and project-local helpers\n",
    "from segment_anything import SamPredictor, sam_model_registry\n",
    "from segment_anything.utils.transforms import ResizeLongestSide\n",
    "from utils.SurfaceDice import compute_dice_coefficient\n",
    "\n",
    "join = os.path.join  # short alias used by the cells below\n",
    "\n",
    "# seed the torch and numpy global RNGs (DataLoader shuffling, box perturbation, test-case pick)\n",
    "SEED = 2023\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% create a dataset class to load npz data and return back image embeddings and ground truth\n",
    "class NpzDataset(Dataset):\n",
    "    '''Serve pre-computed SAM image embeddings, 2D ground-truth masks and perturbed box prompts.\n",
    "\n",
    "    Every `.npz` file in `data_root` must hold an `img_embeddings` array and a `gts` array whose\n",
    "    first axes line up (one entry per 2D slice). All files are stacked into RAM once, which is\n",
    "    fast for feeding the GPU but needs enough memory; loading per-file on demand is the\n",
    "    alternative for larger datasets.\n",
    "\n",
    "    Args:\n",
    "        data_root: directory containing the `.npz` files (other files are ignored).\n",
    "        bbox_shift: each box side is pushed outwards by a random 0..bbox_shift-1 pixels\n",
    "            (0 disables the perturbation). The default reproduces the original behaviour.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if `data_root` contains no `.npz` file.\n",
    "    '''\n",
    "    def __init__(self, data_root, bbox_shift=20):\n",
    "        self.data_root = data_root\n",
    "        self.bbox_shift = bbox_shift\n",
    "        # only .npz files, so stray files in the folder (e.g. .DS_Store) do not break np.load\n",
    "        self.npz_files = sorted(f for f in os.listdir(self.data_root) if f.endswith('.npz'))\n",
    "        if not self.npz_files:\n",
    "            raise FileNotFoundError(f\"no .npz files found in {self.data_root}\")\n",
    "        gts, embeddings = [], []\n",
    "        for f in self.npz_files:\n",
    "            # np.load keeps .npz archives open; the context manager closes each file after reading\n",
    "            with np.load(join(data_root, f)) as d:\n",
    "                gts.append(d['gts'])\n",
    "                embeddings.append(d['img_embeddings'])\n",
    "        self.ori_gts = np.vstack(gts)\n",
    "        self.img_embeddings = np.vstack(embeddings)\n",
    "        print(f\"{self.img_embeddings.shape=}, {self.ori_gts.shape=}\")\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.ori_gts.shape[0]\n",
    "\n",
    "    def _jitter(self):\n",
    "        # np.random.randint(0, 0) raises, so bbox_shift <= 0 means no perturbation\n",
    "        return np.random.randint(0, self.bbox_shift) if self.bbox_shift > 0 else 0\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        '''Return (img_embed float tensor, gt2D[None] long tensor, [x_min, y_min, x_max, y_max] float tensor).\n",
    "\n",
    "        Raises ValueError if the slice's ground-truth mask has no foreground pixel.\n",
    "        '''\n",
    "        img_embed = self.img_embeddings[index]\n",
    "        gt2D = self.ori_gts[index]\n",
    "        y_indices, x_indices = np.where(gt2D > 0)\n",
    "        if x_indices.size == 0:\n",
    "            raise ValueError(f\"slice {index} has an empty ground-truth mask; cannot build a box prompt\")\n",
    "        x_min, x_max = np.min(x_indices), np.max(x_indices)\n",
    "        y_min, y_max = np.min(y_indices), np.max(y_indices)\n",
    "        # add perturbation to bounding box coordinates, clipped to the image extent\n",
    "        H, W = gt2D.shape\n",
    "        x_min = max(0, x_min - self._jitter())\n",
    "        x_max = min(W, x_max + self._jitter())\n",
    "        y_min = max(0, y_min - self._jitter())\n",
    "        y_max = min(H, y_max + self._jitter())\n",
    "        bboxes = np.array([x_min, y_min, x_max, y_max])\n",
    "        # convert img embedding, mask, bounding box to torch tensor\n",
    "        return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :,:]).long(), torch.tensor(bboxes).float()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% test dataset class and dataloader\n",
    "npz_tr_path = 'data/Npz_files/CT_Abd-Gallbladder/train'\n",
    "demo_dataset = NpzDataset(npz_tr_path)\n",
    "demo_dataloader = DataLoader(demo_dataset, batch_size=8, shuffle=True)\n",
    "# fetch one shuffled batch and check the tensor shapes:\n",
    "# img_embed: (B, 256, 64, 64), gt2D: (B, 1, 256, 256), bboxes: (B, 4)\n",
    "first_batch = next(iter(demo_dataloader), None)\n",
    "if first_batch is not None:\n",
    "    img_embed, gt2D, bboxes = first_batch\n",
    "    print(f\"{img_embed.shape=}, {gt2D.shape=}, {bboxes.shape=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% set up model for fine-tuning \n",
    "# train data path\n",
    "npz_tr_path = 'data/Npz_files/CT_Abd-Gallbladder/train'\n",
    "work_dir = './work_dir'\n",
    "task_name = 'CT_Abd-Gallbladder'\n",
    "# prepare SAM model\n",
    "model_type = 'vit_b'\n",
    "checkpoint = 'work_dir/SAM/sam_vit_b_01ec64.pth'\n",
    "# fall back to CPU so the notebook still runs (slowly) on machines without a CUDA GPU\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "model_save_path = join(work_dir, task_name)\n",
    "os.makedirs(model_save_path, exist_ok=True)\n",
    "sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)\n",
    "sam_model.train()\n",
    "\n",
    "# Only the mask decoder is optimised: image embeddings are pre-computed by the dataset and the\n",
    "# prompt encoder runs under torch.no_grad() in the training loop.\n",
    "# Hyperparameter tuning (lr, weight_decay) will improve performance here.\n",
    "optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)\n",
    "# sigmoid=True because the mask decoder returns raw logits\n",
    "seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% train\n",
    "num_epochs = 100\n",
    "losses = []\n",
    "best_loss = 1e10\n",
    "train_dataset = NpzDataset(npz_tr_path)\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
    "# the box transform only depends on the encoder input size, so build it once outside the loop\n",
    "sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)\n",
    "# make sure the model is in training mode even if an evaluation cell was run before\n",
    "sam_model.train()\n",
    "for epoch in range(num_epochs):\n",
    "    epoch_loss = 0\n",
    "    num_batches = 0\n",
    "    # train\n",
    "    for image_embedding, gt2D, boxes in tqdm(train_dataloader):\n",
    "        # do not compute gradients for image encoder and prompt encoder\n",
    "        with torch.no_grad():\n",
    "            # convert box from the mask grid (gt2D's H x W) to the 1024x1024 encoder input grid\n",
    "            box_np = boxes.numpy()\n",
    "            box = sam_trans.apply_boxes(box_np, (gt2D.shape[-2], gt2D.shape[-1]))\n",
    "            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)\n",
    "            if len(box_torch.shape) == 2:\n",
    "                box_torch = box_torch[:, None, :] # (B, 1, 4)\n",
    "            # get prompt embeddings\n",
    "            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(\n",
    "                points=None,\n",
    "                boxes=box_torch,\n",
    "                masks=None,\n",
    "            )\n",
    "        # predicted masks (raw logits; seg_loss applies the sigmoid)\n",
    "        mask_predictions, _ = sam_model.mask_decoder(\n",
    "            image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)\n",
    "            image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)\n",
    "            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)\n",
    "            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)\n",
    "            multimask_output=False,\n",
    "        )\n",
    "\n",
    "        loss = seg_loss(mask_predictions, gt2D.to(device))\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item()\n",
    "        num_batches += 1\n",
    "\n",
    "    # mean loss over all batches of this epoch (dividing by the last enumerate index would\n",
    "    # under-count by one and raise ZeroDivisionError for a single-batch epoch)\n",
    "    epoch_loss /= max(num_batches, 1)\n",
    "    losses.append(epoch_loss)\n",
    "    print(f'EPOCH: {epoch}, Loss: {epoch_loss}')\n",
    "    # save the latest model checkpoint\n",
    "    torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_latest.pth'))\n",
    "    # save the best model (lowest mean training loss so far)\n",
    "    if epoch_loss < best_loss:\n",
    "        best_loss = epoch_loss\n",
    "        torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_best.pth'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot loss\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(losses)\n",
    "ax.set_title('Dice + Cross Entropy Loss')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('Loss')\n",
    "# save BEFORE showing: once plt.show() has run the figure may already be closed\n",
    "# (e.g. with the inline backend), and a later savefig() writes a blank image\n",
    "fig.savefig(join(model_save_path, 'train_loss.png'))\n",
    "plt.show() # comment this line if you are running on a server\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% compare the segmentation results between the original SAM model and the fine-tuned model\n",
    "# load the original SAM model (inference only)\n",
    "ori_sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)\n",
    "ori_sam_model.eval()\n",
    "ori_sam_predictor = SamPredictor(ori_sam_model)\n",
    "# the fine-tuned model is only used for inference from here on\n",
    "sam_model.eval()\n",
    "# build the resize transform here so this cell does not depend on the training cell having run\n",
    "sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)\n",
    "\n",
    "npz_ts_path = 'data/Npz_files/CT_Abd-Gallbladder/test'\n",
    "test_npzs = sorted(f for f in os.listdir(npz_ts_path) if f.endswith('.npz'))\n",
    "# random select a test case\n",
    "npz_idx = np.random.randint(0, len(test_npzs))\n",
    "# read the arrays and close the archive right away\n",
    "with np.load(join(npz_ts_path, test_npzs[npz_idx])) as npz:\n",
    "    imgs = npz['imgs']\n",
    "    gts = npz['gts']\n",
    "\n",
    "def get_bbox_from_mask(mask, bbox_shift=20):\n",
    "    '''Return [x_min, y_min, x_max, y_max] around the foreground (mask > 0) of a 2D mask.\n",
    "\n",
    "    Each side is pushed outwards by a random 0..bbox_shift-1 pixels (0 disables this) and\n",
    "    clipped to the mask extent, mimicking an imprecise user box as in the training dataset.\n",
    "    Raises ValueError if the mask has no foreground pixel.\n",
    "    '''\n",
    "    y_indices, x_indices = np.where(mask > 0)\n",
    "    if x_indices.size == 0:\n",
    "        raise ValueError('mask has no foreground pixels; cannot build a box prompt')\n",
    "    x_min, x_max = np.min(x_indices), np.max(x_indices)\n",
    "    y_min, y_max = np.min(y_indices), np.max(y_indices)\n",
    "\n",
    "    def jitter():\n",
    "        # np.random.randint(0, 0) raises, so bbox_shift <= 0 means no perturbation\n",
    "        return np.random.randint(0, bbox_shift) if bbox_shift > 0 else 0\n",
    "\n",
    "    # add perturbation to bounding box coordinates\n",
    "    H, W = mask.shape\n",
    "    x_min = max(0, x_min - jitter())\n",
    "    x_max = min(W, x_max + jitter())\n",
    "    y_min = max(0, y_min - jitter())\n",
    "    y_max = min(H, y_max + jitter())\n",
    "    return np.array([x_min, y_min, x_max, y_max])\n",
    "\n",
    "def predict_with_finetuned_sam(img, bbox):\n",
    "    '''Segment one H x W x 3 image with the fine-tuned `sam_model`, prompted by `bbox` in `img` pixel coordinates.\n",
    "\n",
    "    Returns a uint8 {0, 1} mask at the mask decoder's output resolution (sigmoid > 0.5).\n",
    "    NOTE(review): this matches `gts` only when the test slices are already at the decoder's\n",
    "    low-res output size (assumed 256 x 256) - confirm for other inputs.\n",
    "    '''\n",
    "    H, W = img.shape[:2]\n",
    "    resize_img = sam_trans.apply_image(img)\n",
    "    resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)\n",
    "    with torch.no_grad():\n",
    "        input_image = sam_model.preprocess(resize_img_tensor[None, :, :, :]) # (1, 3, 1024, 1024)\n",
    "        image_embedding = sam_model.image_encoder(input_image) # (1, 256, 64, 64)\n",
    "        # convert box to 1024x1024 grid\n",
    "        box_1024 = sam_trans.apply_boxes(bbox, (H, W))\n",
    "        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=device)\n",
    "        if len(box_torch.shape) == 2:\n",
    "            box_torch = box_torch[:, None, :] # (B, 1, 4)\n",
    "        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(\n",
    "            points=None,\n",
    "            boxes=box_torch,\n",
    "            masks=None,\n",
    "        )\n",
    "        seg_logits, _ = sam_model.mask_decoder(\n",
    "            image_embeddings=image_embedding, # (B, 256, 64, 64)\n",
    "            image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)\n",
    "            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)\n",
    "            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)\n",
    "            multimask_output=False,\n",
    "        )\n",
    "        seg_prob = torch.sigmoid(seg_logits).cpu().numpy().squeeze()\n",
    "    # convert soft mask to hard mask\n",
    "    return (seg_prob > 0.5).astype(np.uint8)\n",
    "\n",
    "ori_sam_segs = []\n",
    "medsam_segs = []\n",
    "bboxes = []\n",
    "for img, gt in zip(imgs, gts):\n",
    "    bbox = get_bbox_from_mask(gt)\n",
    "    bboxes.append(bbox)\n",
    "    # predict the segmentation mask using the original SAM model\n",
    "    ori_sam_predictor.set_image(img)\n",
    "    ori_sam_seg, _, _ = ori_sam_predictor.predict(point_coords=None, box=bbox, multimask_output=False)\n",
    "    ori_sam_segs.append(ori_sam_seg[0])\n",
    "    # predict the segmentation mask using the fine-tuned model (same box prompt)\n",
    "    medsam_segs.append(predict_with_finetuned_sam(img, bbox))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% compute the DSC score\n",
    "# binarise the ground truth once and stack the per-slice predictions into volumes\n",
    "gt_fg = gts > 0\n",
    "ori_sam_segs = np.stack(ori_sam_segs, axis=0)\n",
    "medsam_segs = np.stack(medsam_segs, axis=0)\n",
    "# volume-level Dice similarity coefficient for each model\n",
    "ori_sam_dsc = compute_dice_coefficient(gt_fg, ori_sam_segs > 0)\n",
    "medsam_dsc = compute_dice_coefficient(gt_fg, medsam_segs > 0)\n",
    "print('Original SAM DSC: {:.4f}'.format(ori_sam_dsc), 'MedSAM DSC: {:.4f}'.format(medsam_dsc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% visualize the segmentation results of the middle slice\n",
    "# plotting helpers adapted from\n",
    "# https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb\n",
    "# (the default mask colour is changed to avoid red and green)\n",
    "def show_mask(mask, ax, random_color=False):\n",
    "    '''Overlay `mask` (reshaped to its last two dims) on `ax` as an RGBA layer with alpha 0.6: yellow, or random RGB if `random_color`.'''\n",
    "    rgba = np.append(np.random.random(3), 0.6) if random_color else np.array([251/255, 252/255, 30/255, 0.6])\n",
    "    height, width = mask.shape[-2:]\n",
    "    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)\n",
    "    ax.imshow(overlay)\n",
    "\n",
    "def show_box(box, ax):\n",
    "    '''Draw box = [x0, y0, x1, y1] on `ax` as an unfilled blue rectangle.'''\n",
    "    left, top, right, bottom = box[0], box[1], box[2], box[3]\n",
    "    outline = plt.Rectangle((left, top), right - left, bottom - top, edgecolor='blue', facecolor=(0,0,0,0), lw=2)\n",
    "    ax.add_patch(outline)\n",
    "\n",
    "\n",
    "img_id = imgs.shape[0] // 2  # middle slice of the test volume\n",
    "save_fig = False  # set True to also save the figure next to the model checkpoints\n",
    "fig, axs = plt.subplots(1, 3, figsize=(25, 25))\n",
    "\n",
    "# panel 0: ground-truth mask\n",
    "axs[0].imshow(imgs[img_id])\n",
    "show_mask(gts[img_id], axs[0])\n",
    "axs[0].set_title('Ground Truth', fontsize=20)\n",
    "axs[0].axis('off')\n",
    "\n",
    "# panel 1: original SAM prediction with its box prompt\n",
    "axs[1].imshow(imgs[img_id])\n",
    "show_mask(ori_sam_segs[img_id], axs[1])\n",
    "show_box(bboxes[img_id], axs[1])\n",
    "# the DSC written on the slice is computed over the whole test volume\n",
    "axs[1].text(0.5, 0.5, 'SAM DSC: {:.4f}'.format(ori_sam_dsc), fontsize=30, horizontalalignment='left', verticalalignment='top', color='yellow')\n",
    "axs[1].set_title('Original SAM', fontsize=20)\n",
    "axs[1].axis('off')\n",
    "\n",
    "# panel 2: fine-tuned model prediction with the same box prompt\n",
    "axs[2].imshow(imgs[img_id])\n",
    "show_mask(medsam_segs[img_id], axs[2])\n",
    "show_box(bboxes[img_id], axs[2])\n",
    "axs[2].text(0.5, 0.5, 'MedSAM DSC: {:.4f}'.format(medsam_dsc), fontsize=30, horizontalalignment='left', verticalalignment='top', color='yellow')\n",
    "axs[2].set_title('Fine-tuned SAM (MedSAM)', fontsize=20)\n",
    "axs[2].axis('off')\n",
    "\n",
    "# lay out and save BEFORE plt.show(): changes made after the figure is shown are not displayed\n",
    "fig.subplots_adjust(wspace=0.01, hspace=0)\n",
    "if save_fig:\n",
    "    fig.savefig(join(model_save_path, test_npzs[npz_idx].split('.npz')[0] + str(img_id).zfill(3) + '.png'), bbox_inches='tight', dpi=300)\n",
    "plt.show()\n",
    "plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "medsam-demo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
